"""CoH2 basic plots
v 1.0
22.06.2020
Hannibal"""

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math
import scipy.stats


class Vehicle:
    """Combat stats of one CoH2 unit/weapon combination, as read from the unit data sheet.

    The constructor stores every argument unchanged as an attribute. The methods derive
    range-dependent values (penetration, base accuracy, AoE damage, scatter area) from those stats.
    Distances use the same unit as the ``range_*`` / ``aoe_dis_*`` arguments; the plots label them metres.
    """

    def __init__(self, name, class_, faction, hitpoints, damage, deflection_dmg_modifier, target_size, armor_front, armor_rear, pen_near, pen_mid, pen_far, range_,
                 range_refdis_near, range_refdis_mid, range_refdis_far, acc_near, acc_mid, acc_far, ready_aim_time_min, ready_aim_time_max, wind_up,
                 wind_down, fire_aim_time_min, fire_aim_time_max, reload_min, reload_max, scatter_offset, scatter_ratio, distance_scatter_max, scatter_angle,
                 aoe_dmg_near, aoe_dmg_mid, aoe_dmg_far, aoe_dis_near, aoe_dis_mid, aoe_dis_far, aoe_dis_max, hitbox_length, hitbox_broadness):
        # identity
        self.name = name
        self.class_ = class_
        self.faction = faction

        # durability / target properties
        self.hitpoints = hitpoints
        self.armor_front = armor_front
        self.armor_rear = armor_rear
        self.target_size = target_size
        self.hitbox_length = hitbox_length
        self.hitbox_broadness = hitbox_broadness

        # weapon: damage and penetration
        self.damage = damage
        self.deflection_damage_modifier = deflection_dmg_modifier
        self.pen_near = pen_near
        self.pen_mid = pen_mid
        self.pen_far = pen_far

        # weapon: range and its reference distances for near/mid/far stats
        self.range = range_
        self.range_refdis_near = range_refdis_near
        self.range_refdis_mid = range_refdis_mid
        self.range_refdis_far = range_refdis_far

        # weapon: accuracy at the reference distances
        self.acc_near = acc_near
        self.acc_mid = acc_mid
        self.acc_far = acc_far

        # weapon: rate-of-fire timings
        self.ready_aim_time_min = ready_aim_time_min
        self.ready_aim_time_max = ready_aim_time_max
        self.wind_up = wind_up
        self.wind_down = wind_down
        self.fire_aim_time_min = fire_aim_time_min
        self.fire_aim_time_max = fire_aim_time_max
        self.reload_min = reload_min
        self.reload_max = reload_max

        # weapon: scatter
        self.scatter_offset = scatter_offset
        self.scatter_ratio = scatter_ratio
        self.distance_scatter_max = distance_scatter_max
        self.scatter_angle = scatter_angle

        # weapon: area-of-effect damage multipliers and their radii
        self.aoe_dmg_near = aoe_dmg_near
        self.aoe_dmg_mid = aoe_dmg_mid
        self.aoe_dmg_far = aoe_dmg_far
        self.aoe_dis_near = aoe_dis_near
        self.aoe_dis_mid = aoe_dis_mid
        self.aoe_dis_far = aoe_dis_far
        self.aoe_dis_max = aoe_dis_max

    @staticmethod
    def _lerp(x, x0, x1, y0, y1):
        """Linearly interpolate the value at x between the points (x0, y0) and (x1, y1)."""
        weight = (x - x0) / (x1 - x0)
        return y0 * (1 - weight) + y1 * weight

    def calc_pen(self, distance):
        """Return the penetration value at ``distance``.

        The value is flat ``pen_near`` up to the near reference distance. Between neighbouring
        reference distances it is linearly interpolated. Beyond the far reference distance it is
        flat ``pen_far``, and 0 beyond ``range``.
        """
        penetration = None

        if distance > self.range:
            penetration = 0
        elif distance <= self.range_refdis_near:
            penetration = self.pen_near
        elif distance <= self.range_refdis_mid:
            penetration = self._lerp(distance, self.range_refdis_near, self.range_refdis_mid,
                                     self.pen_near, self.pen_mid)
        elif distance <= self.range_refdis_far:
            penetration = self._lerp(distance, self.range_refdis_mid, self.range_refdis_far,
                                     self.pen_mid, self.pen_far)
        elif distance <= self.range:
            penetration = self.pen_far

        return penetration

    def calc_aoe_damage(self, distance_):
        """Return the AoE damage dealt at ``distance_`` from the impact point.

        The result is ``damage`` times a multiplier. The multiplier is ``aoe_dmg_near`` inside
        ``aoe_dis_near``, is linearly interpolated near→mid and mid→far, is ``aoe_dmg_far`` up to
        ``aoe_dis_max``, and drops to 0 beyond ``aoe_dis_max``.
        """
        # Non-numeric sentinel: it only survives if no branch matched (e.g. a NaN distance or stat),
        # so downstream numeric use fails instead of silently using a number.
        dmg = ""
        if distance_ > self.aoe_dis_max:
            dmg = 0
        elif distance_ <= self.aoe_dis_near:
            dmg = self.damage * self.aoe_dmg_near
        elif distance_ <= self.aoe_dis_mid:
            dmg = self.damage * self.aoe_dmg_near \
                  + self.damage * (distance_ - self.aoe_dis_near) \
                  * (self.aoe_dmg_mid - self.aoe_dmg_near) / (self.aoe_dis_mid - self.aoe_dis_near)
        elif distance_ <= self.aoe_dis_far:
            dmg = self.damage * self.aoe_dmg_mid \
                  + self.damage * (distance_ - self.aoe_dis_mid) \
                  * (self.aoe_dmg_far - self.aoe_dmg_mid) / (self.aoe_dis_far - self.aoe_dis_mid)
        elif distance_ <= self.aoe_dis_max:
            dmg = self.damage * self.aoe_dmg_far
        return dmg

    def calc_base_acc(self, distance):
        """Return the base accuracy at ``distance``. Multiply it by the target's size to get a hit chance.

        The value is flat ``acc_near`` up to the near range reference distance. Between neighbouring
        reference distances it is linearly interpolated. Beyond the far reference distance it is
        flat ``acc_far``, and 0 beyond ``range``.
        """
        base_accuracy = None

        if distance > self.range:
            base_accuracy = 0
        elif distance <= self.range_refdis_near:  # was aoe_dis_near: wrong reference distance
            base_accuracy = self.acc_near
        elif distance <= self.range_refdis_mid:
            base_accuracy = self._lerp(distance, self.range_refdis_near, self.range_refdis_mid,
                                       self.acc_near, self.acc_mid)
        elif distance <= self.range_refdis_far:
            base_accuracy = self._lerp(distance, self.range_refdis_mid, self.range_refdis_far,
                                       self.acc_mid, self.acc_far)
        elif distance <= self.range:
            base_accuracy = self.acc_far
        return base_accuracy

    def calc_avg_time_to_next_shot(self):
        """Compute the average time between shots from the rate-of-fire stats.

        The result is stored in ``self.average_time_to_shoot`` (only used for inspection) and also
        returned.

        Each non-zero phase among wind-down, wind-up, max fire-aim time and max reload adds a fixed
        0.125 correction.
        """
        tbs_correction_value = sum(0.125 for phase in (self.wind_down, self.wind_up, self.fire_aim_time_max, self.reload_max)
                                   if phase != 0)

        self.average_time_to_shoot = self.wind_up + self.wind_down + (self.fire_aim_time_min + self.fire_aim_time_max) / 2 \
                                     + (self.reload_min + self.reload_max) / 2 + tbs_correction_value
        return self.average_time_to_shoot

    def calc_scatter_area(self, distance_):
        """Describe the scatter zone of a shot aimed at ``distance_``.

        Returns a tuple ``(corner_min_l, corner_min_r, corner_max_l, corner_max_r,
        long_scatter_min, long_scatter_max, scatter_area)``.

        - The corners are (x, y) points. The x axis lies along the firing line.
        - ``long_scatter_min`` / ``long_scatter_max`` are the near and far radii of the zone.
        - ``scatter_area`` is the area of the ring sector between those radii that spans
          ``scatter_angle`` degrees.
        """
        # longitudinal scatter: proportional to distance, capped, and shifted by the offset
        long_scatter = min(self.distance_scatter_max, distance_ * self.scatter_ratio)
        scatter_centre = distance_ + self.scatter_offset * long_scatter
        long_scatter_min = scatter_centre - long_scatter
        long_scatter_max = scatter_centre + long_scatter

        # scatter_angle is the full cone width, so each edge lies half of it off the firing line
        angle_scatter_rad = math.radians(self.scatter_angle / 2)
        cos_edge = math.cos(angle_scatter_rad)
        sin_edge = math.sin(angle_scatter_rad)

        x_short, y_short = cos_edge * long_scatter_min, sin_edge * long_scatter_min
        x_far, y_far = cos_edge * long_scatter_max, sin_edge * long_scatter_max

        corner_min_l = (x_short, y_short)
        corner_min_r = (x_short, -y_short)
        corner_max_l = (x_far, y_far)
        corner_max_r = (x_far, -y_far)

        scatter_area = ((long_scatter_max ** 2 * math.pi) - (long_scatter_min ** 2 * math.pi)) * self.scatter_angle / 360

        return corner_min_l, corner_min_r, corner_max_l, corner_max_r, long_scatter_min, long_scatter_max, scatter_area


def read_data_file(path="Unit_Data.xlsx"):
    """Read the unit data workbook into a DataFrame indexed by unit ID.

    :param path: Excel workbook to read. Defaults to ``Unit_Data.xlsx`` in the working directory.
    :return: DataFrame indexed by the ``ID`` column. The first row after the header is removed,
             and so are the derived/unused columns listed below.
    """
    unused_columns = ["Reload_avg", "BurstLength", "Burst_multi_near", "Burst_multi_mid", "Burst_multi_far", "Burst_ROF", "CD_multi_near",
                      "CD_multi_mid", "CD_multi_far", "CD_min", "CD_max", "CD_avg", "Delay", "Shots/Reload", "ScatterArea",
                      "Radius", "Shots/Salvo", "Targeted", "Recharge", "Recharge_early", "Subsets", "TBox_X",
                      "TBox_Y", "TBox_subs", "Active_sub"]

    veh_data_ = pd.read_excel(path, index_col="ID")
    # NOTE(review): the first row below the header is skipped. It presumably holds units or
    # descriptions rather than a unit; confirm against the workbook.
    # Chain a non-inplace drop instead of mutating the iloc slice in place, which may trigger
    # pandas' SettingWithCopyWarning.
    return veh_data_.iloc[1:, :].drop(columns=unused_columns)


def read_user_input(path="vehicles.txt"):
    """Read the comma-separated list of unit IDs to compare, print it and return it.

    Entries may be spread over several lines. Surrounding whitespace is stripped, and empty
    entries (blank lines, trailing commas) are skipped so they are not passed on as unit IDs.

    :param path: text file to read. Defaults to ``vehicles.txt`` in the working directory.
    :return: list of unit IDs in file order.
    """
    # utf-8-sig also accepts files saved with a UTF-8 byte-order mark (common on Windows editors)
    with open(path, encoding="utf-8-sig") as f:
        vehicles = [entry.strip() for line in f for entry in line.split(",") if entry.strip()]

    print("Recognized vehicles:")
    for entry in vehicles:
        print(entry)

    return vehicles


def build_unit(df, veh_index):
    """Create a Vehicle from the row of ``df`` labelled ``veh_index``.

    The row label becomes the vehicle's name. All other constructor arguments are read from
    the sheet columns below, listed in Vehicle.__init__ parameter order.
    """
    column_order = (
        "Class", "Faction", "Health", "Damage", "Deflection_multi", "Target_size",
        "Armor_front", "Armor_rear", "Pen_near", "Pen_mid", "Pen_far",
        "Range_max", "Range_refdis_near", "Range_refdis_mid", "Range_refdis_far",
        "Accuracy_near", "Accuracy_mid", "Accuracy_far",
        "Ready_min", "Ready_max", "Windup", "Winddown",
        "Fire_aim_min", "Fire_aim_max", "Reload_min", "Reload_max",
        "Scatter_offset", "Scatter_ratio", "Dis_scatter_max", "Scatter_angle",
        "AoE_dmg_near", "AoE_dmg_mid", "AoE_dmg_far",
        "AoE_dis_near", "AoE_dis_mid", "AoE_dis_far", "AoE_dis_max",
        "Hitbox_length", "Hitbox_broadness",
    )
    row = df.loc[veh_index]
    return Vehicle(veh_index, *(row[column] for column in column_order))


def pen_damage_plot(target, shooter):
    """Plot how reliably ``shooter`` damages ``target`` from the front.

    Top panel, over the shooter's range:
    - the penetration chance (pen / target.armor_front, capped at 1);
    - the natural hit chance (base accuracy * target size, capped at 1);
    - their product.

    Bottom panel: the binomial probability of scoring exactly the number of penetrations needed
    to kill (hitpoints / damage, rounded up) within x shots, at max range and at range 10.
    Vertical markers show the expected number of shots to kill at each range.
    """
    def pen_chance(distance):
        # capped at 1: penetration above the armour value always penetrates
        return min(1, shooter.calc_pen(distance) / target.armor_front)

    distances = np.arange(0, shooter.range + 0.02, 0.01)
    pen_chances = [pen_chance(dis) for dis in distances]
    natural_hit_chance = [min(1, shooter.calc_base_acc(dis) * target.target_size) for dis in distances]
    pen_damage_chance = [min(1, pen * hit) for pen, hit in zip(pen_chances, natural_hit_chance)]

    # Use the capped chances here too: scipy's binom.pmf returns nan for p > 1.
    pen_chance_max_range = pen_chance(shooter.range)
    pen_chance_10_range = pen_chance(10)
    pens_to_kill = math.ceil(target.hitpoints / shooter.damage)

    # Expected shots for k successes at success probability p is k / p (unbounded if p == 0).
    expected_shots_to_kill_max_range = pens_to_kill / pen_chance_max_range if pen_chance_max_range > 0 else math.inf
    expected_shots_to_kill_10_range = pens_to_kill / pen_chance_10_range if pen_chance_10_range > 0 else math.inf
    finite_expectations = [e for e in (expected_shots_to_kill_max_range, expected_shots_to_kill_10_range) if math.isfinite(e)]
    no_of_shots = np.arange(math.ceil(max(finite_expectations, default=pens_to_kill) * 2.5))

    fig, ax = plt.subplots(2, 1, figsize=(9, 6))

    ax[0].set_title("{} shooting at {}".format(shooter.name, target.name))
    ax[0].set_xlabel("distance [m]")
    ax[0].set_ylabel("chance")
    ax[0].set_ylim(0, 1.1)
    ax[0].plot(distances, pen_chances, distances, natural_hit_chance, distances, pen_damage_chance)
    ax[0].legend(["penetration chance", "natural hit chance", "damage chance"], bbox_to_anchor=(1, 1), loc="upper left")

    ax[1].set_xlabel("shots")
    ax[1].set_ylabel("probability")
    ax[1].plot(no_of_shots, scipy.stats.binom.pmf(pens_to_kill, no_of_shots, pen_chance_max_range), 'b-',
               label="chance for {} penetrations\nout of x shots at range {}".format(pens_to_kill, shooter.range))
    ax[1].plot(no_of_shots, scipy.stats.binom.pmf(pens_to_kill, no_of_shots, pen_chance_10_range), 'c:',
               label="chance for {} penetrations\nout of x shots at range 10".format(pens_to_kill))
    max_y_of_plot = ax[1].get_ylim()[1]
    # Markers are only drawn when the target can be penetrated at that range at all.
    # Labels are attached per artist, so a skipped marker cannot shift the legend entries.
    if math.isfinite(expected_shots_to_kill_max_range):
        ax[1].vlines(expected_shots_to_kill_max_range, 0, max_y_of_plot, color='r',
                     label="expected shots to kill\nat range {}".format(shooter.range))
    if math.isfinite(expected_shots_to_kill_10_range):
        ax[1].vlines(expected_shots_to_kill_10_range, 0, max_y_of_plot, color='m', linestyles=":",
                     label="expected shots to kill\nat range 10")
    ax[1].legend(bbox_to_anchor=(1, 1), loc="upper left")

    plt.show()


def pen_plots(target, *shooters):
    """Plot each shooter's frontal penetration chance against ``target`` over the shooter's range.

    The chance is the shooter's penetration divided by target.armor_front, capped at 1.
    """
    names = []

    for shooter in shooters:
        x_values = np.arange(0, shooter.range + 0.02, 0.01)
        y_values = [min(1, shooter.calc_pen(x) / target.armor_front) for x in x_values]
        plt.plot(x_values, y_values)
        names.append(shooter.name)

    plt.title("penetration chance vs {}".format(target.name))
    plt.xlabel("distance [m]")
    plt.ylabel("chance")
    plt.ylim(0, 1.1)
    plt.legend(names, bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()


def aoe_plot(*vehicles):
    """Plot AoE damage against distance from the impact point for each vehicle.

    A red horizontal reference line is drawn at 80 damage.
    """
    for vehicle in vehicles:
        radii = np.arange(0, vehicle.aoe_dis_max + 0.002, 0.001)
        plt.plot(radii, [vehicle.calc_aoe_damage(radius) for radius in radii])

    plt.title("AoE damage")
    plt.xlabel("distance [m]")
    plt.ylabel("damage")
    plt.ylim(0)
    right_edge = plt.gca().get_xlim()[1]
    plt.hlines(80, 0, right_edge, color="r", label="80 HP threshold")
    # legend entries follow the drawing order: one per vehicle, then the threshold line
    labels = [vehicle.name for vehicle in vehicles] + ["80 HP threshold"]
    plt.legend(labels, bbox_to_anchor=(1, 1), loc="upper left")

    plt.show()


def acc_plots(*vehicles):
    """Plot each vehicle's base accuracy (before the target-size multiplier) over its range."""
    names = []

    for vehicle in vehicles:
        x_values = np.arange(0, vehicle.range + 0.02, 0.01)
        plt.plot(x_values, [vehicle.calc_base_acc(x) for x in x_values])
        names.append(vehicle.name)

    plt.title("base accuracy profile")
    plt.xlabel("distance [m]")
    plt.ylabel("base accuracy")
    plt.ylim(0)
    plt.legend(names, bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()


def scatter_areas_plots(*vehicles):
    """Plot the area of each vehicle's scatter zone against target distance, up to its range."""
    names = []

    for vehicle in vehicles:
        x_values = np.arange(0, vehicle.range + 0.01, 0.01)
        # calc_scatter_area returns the area as the last element of its result tuple
        y_values = [vehicle.calc_scatter_area(x)[-1] for x in x_values]
        plt.plot(x_values, y_values)
        names.append(vehicle.name)

    plt.title("scatter area")
    plt.xlabel("distance [m]")
    plt.ylabel("area [m^2]")
    plt.ylim(0)
    plt.legend(names, bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()


def main():
    """Load the unit sheet and the requested vehicle IDs, then show all comparison plots.

    The first listed vehicle is the target of the penetration plots. pen_damage_plot uses the
    second listed vehicle as the shooter.
    """
    print("CoH2 basic plots\nv 1.0\n22.06.2020\nHannibal \n\n")

    plt.rcParams.update({'figure.autolayout': True})

    data = read_data_file()
    vehicle_input = read_user_input()
    vehicle_list = [build_unit(data, entry) for entry in vehicle_input]

    aoe_plot(*vehicle_list)
    scatter_areas_plots(*vehicle_list)
    acc_plots(*vehicle_list)

    if len(vehicle_list) > 1:
        pen_damage_plot(vehicle_list[0], vehicle_list[1])
        # NOTE(review): the target itself is also passed as one of the shooters; confirm this is intended.
        pen_plots(vehicle_list[0], *vehicle_list)


# Guarded so the module can be imported (e.g. to reuse Vehicle) without reading files or opening plots.
if __name__ == "__main__":
    main()
